# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license

import unittest

import dns.exception
import dns.flags
import dns.message
import dns.renderer
import dns.tsig
import dns.tsigkeyring

# Text form of the message that test_basic compares against: a response
# (QR flag set) with EDNS version 0 and a 4096 payload, one question for
# foo.example. IN A, and two A records in the answer section.
basic_answer = """flags QR
edns 0
payload 4096
;QUESTION
foo.example. IN A
;ANSWER
foo.example. 30 IN A 10.0.0.1
foo.example. 30 IN A 10.0.0.2
"""


class RendererTestCase(unittest.TestCase):
    """Tests for ``dns.renderer.Renderer``.

    Most tests build a message with the renderer, parse the resulting wire
    data back with ``dns.message.from_wire()`` and compare the result with
    an independently constructed expected message.
    """

    def test_basic(self):
        """Render a question, an answer rdataset and EDNS, then parse it back."""
        r = dns.renderer.Renderer(flags=dns.flags.QR, max_size=512)
        qname = dns.name.from_text("foo.example")
        r.add_question(qname, dns.rdatatype.A)
        rds = dns.rdataset.from_text("in", "a", 30, "10.0.0.1", "10.0.0.2")
        r.add_rdataset(dns.renderer.ANSWER, qname, rds)
        r.add_edns(0, 0, 4096)
        r.write_header()
        wire = r.get_wire()
        message = dns.message.from_wire(wire)
        expected = dns.message.from_text(basic_answer)
        # Our rendered message purposely has a random query id so we
        # exercise that code, so copy it into the expected message.
        expected.id = message.id
        self.assertEqual(message, expected)

    def test_tsig(self):
        """Render a query with a TSIG added and parse it back with the keyring.

        The parsed message is expected to compare equal to the query built
        by ``dns.message.make_query()`` once the ids are aligned.
        """
        r = dns.renderer.Renderer(flags=dns.flags.RD, max_size=512)
        qname = dns.name.from_text("foo.example")
        r.add_question(qname, dns.rdatatype.A)
        keyring = dns.tsigkeyring.from_text({"key": "12345678"})
        keyname = next(iter(keyring))
        r.write_header()
        r.add_tsig(
            keyname, keyring[keyname], 300, r.id, 0, b"", b"", dns.tsig.HMAC_SHA256
        )
        wire = r.get_wire()
        message = dns.message.from_wire(wire, keyring=keyring)
        expected = dns.message.make_query(qname, dns.rdatatype.A)
        # The renderer picked its own id; align the expected message with it.
        expected.id = message.id
        self.assertEqual(message, expected)

    def test_multi_tsig(self):
        """Render two messages with ``add_multi_tsig`` and parse both back.

        The first call is given ``None`` as its context; the context it
        returns is passed to the call for the second message.  On the
        parsing side, the ``tsig_ctx`` of the first parsed message is passed
        to ``from_wire()`` for the second one.
        """
        qname = dns.name.from_text("foo.example")
        keyring = dns.tsigkeyring.from_text({"key": "12345678"})
        keyname = next(iter(keyring))

        # First message: no previous signing context.
        r = dns.renderer.Renderer(flags=dns.flags.RD, max_size=512)
        r.add_question(qname, dns.rdatatype.A)
        r.write_header()
        ctx = r.add_multi_tsig(
            None,
            keyname,
            keyring[keyname],
            300,
            r.id,
            0,
            b"",
            b"",
            dns.tsig.HMAC_SHA256,
        )
        wire = r.get_wire()
        message = dns.message.from_wire(wire, keyring=keyring, multi=True)
        expected = dns.message.make_query(qname, dns.rdatatype.A)
        expected.id = message.id
        self.assertEqual(message, expected)

        # Second message: continue from the context returned above.  The
        # context returned by this call is not needed, so it is discarded.
        r = dns.renderer.Renderer(flags=dns.flags.RD, max_size=512)
        r.add_question(qname, dns.rdatatype.A)
        r.write_header()
        r.add_multi_tsig(
            ctx, keyname, keyring[keyname], 300, r.id, 0, b"", b"", dns.tsig.HMAC_SHA256
        )
        wire = r.get_wire()
        message = dns.message.from_wire(
            wire, keyring=keyring, tsig_ctx=message.tsig_ctx, multi=True
        )
        expected = dns.message.make_query(qname, dns.rdatatype.A)
        expected.id = message.id
        self.assertEqual(message, expected)

    def test_going_backwards_fails(self):
        """Adding an answer rdataset after EDNS has been added raises FormError."""
        r = dns.renderer.Renderer(flags=dns.flags.QR, max_size=512)
        qname = dns.name.from_text("foo.example")
        r.add_question(qname, dns.rdatatype.A)
        r.add_edns(0, 0, 4096)
        rds = dns.rdataset.from_text("in", "a", 30, "10.0.0.1", "10.0.0.2")
        with self.assertRaises(dns.exception.FormError):
            r.add_rdataset(dns.renderer.ANSWER, qname, rds)

    def test_reservation(self):
        """Check ``reserve()`` / ``release_reserved()`` and rejected sizes.

        Reserving 100 of 512 bytes must leave ``max_size`` at 412, releasing
        must restore 512, and reserving a negative amount or more than the
        maximum size must raise ``ValueError``.
        """
        r = dns.renderer.Renderer(flags=dns.flags.QR, max_size=512)
        r.reserve(100)
        # Use unittest assertions rather than bare ``assert`` so the checks
        # still run under ``python -O`` and report both values on failure.
        self.assertEqual(r.max_size, 412)
        r.release_reserved()
        self.assertEqual(r.max_size, 512)
        with self.assertRaises(ValueError):
            r.reserve(-1)
        with self.assertRaises(ValueError):
            r.reserve(513)
